Add FastViT model #26172
Conversation
This reverts commit 118dad1.
cc @rafaelpadilla 😉 to keep your eyes on this!
Hi @JorgeAV-ai,
Nice work! :)
Tests should be all green before we pass it to a core maintainers review. Also, I noted a few conversations that should be resolved.
Please, let me know if you need help with them.
I added some questions above. I also noticed that some of your comments might be related to outdated code. Would you mind taking a look again? Thanks 😊
if "stem" in name: | ||
name = name.replace("stem", "embeddings.patch_embeddings.projection") | ||
if "conv_kxk" in name: | ||
name = name.replace("conv_kxk", "rbr_conv") | ||
if "conv_scale" in name: | ||
name = name.replace("conv_scale", "rbr_scale") | ||
if "identity" in name: | ||
name = name.replace("identity", "rbr_skip") | ||
if "0.conv" in name: | ||
name = name.replace("0.conv", "conv") | ||
if "0.bn" in name: | ||
name = name.replace("0.bn", "bn") | ||
if "stages" in name: | ||
name = name.replace("stages", "encoder.layer") | ||
if "blocks" in name: | ||
name = name.replace("blocks", "stage_conv") | ||
if "layer_scale.gamma" in name: | ||
name = name.replace("layer_scale.gamma", "layer_scale") | ||
name = name.replace("token_mixer", "token_mixer_block.token_mixer") | ||
if "layer_scale_1.gamma" in name: | ||
name = name.replace("layer_scale_1.gamma", "layer_scale_1") | ||
if "layer_scale_2.gamma" in name: | ||
name = name.replace("layer_scale_2.gamma", "layer_scale_2") | ||
if "token_mixer.norm" in name: | ||
name = name.replace("token_mixer.norm", "token_mixer_block.token_mixer.norm") | ||
if "token_mixer.mixer" in name: | ||
name = name.replace("token_mixer.mixer", "token_mixer_block.token_mixer.mixer") | ||
if "mlp" in name: | ||
name = name.replace("mlp", "convffn") | ||
if ".conv.conv" in name: | ||
name = name.replace("conv.conv", "conv") | ||
if ".conv.bn" in name: | ||
name = name.replace("conv.bn", "bn") | ||
if "proj." in name: | ||
if "token_mixer" not in name: | ||
name_split = name.split(".") | ||
pos = int(name_split[2]) | ||
name_split[2] = str(pos - 1) | ||
if int(name_split[5]) == 0: | ||
name_split[4] = "reparam_large_conv" | ||
else: | ||
name_split[4] = "conv" | ||
name_split.pop(5) # drop the 0 or 1.... | ||
name = ".".join(name_split) | ||
else: | ||
name = name.replace("token_mixer.proj", "token_mixer_block.attention.proj") | ||
if "se.fc1" in name: | ||
name = name.replace("se.fc1", "se.reduce") | ||
if "se.fc2" in name: | ||
name = name.replace("se.fc2", "se.expand") | ||
if "q_bias" in name: | ||
name = name.replace("q_bias", "query.bias") | ||
if "k_bias" in name: | ||
name = name.replace("k_bias", "key.bias") | ||
if "v_bias" in name: | ||
name = name.replace("v_bias", "value.bias") |
maybe (but double check if gives the same data):
if "stem" in name: | |
name = name.replace("stem", "embeddings.patch_embeddings.projection") | |
if "conv_kxk" in name: | |
name = name.replace("conv_kxk", "rbr_conv") | |
if "conv_scale" in name: | |
name = name.replace("conv_scale", "rbr_scale") | |
if "identity" in name: | |
name = name.replace("identity", "rbr_skip") | |
if "0.conv" in name: | |
name = name.replace("0.conv", "conv") | |
if "0.bn" in name: | |
name = name.replace("0.bn", "bn") | |
if "stages" in name: | |
name = name.replace("stages", "encoder.layer") | |
if "blocks" in name: | |
name = name.replace("blocks", "stage_conv") | |
if "layer_scale.gamma" in name: | |
name = name.replace("layer_scale.gamma", "layer_scale") | |
name = name.replace("token_mixer", "token_mixer_block.token_mixer") | |
if "layer_scale_1.gamma" in name: | |
name = name.replace("layer_scale_1.gamma", "layer_scale_1") | |
if "layer_scale_2.gamma" in name: | |
name = name.replace("layer_scale_2.gamma", "layer_scale_2") | |
if "token_mixer.norm" in name: | |
name = name.replace("token_mixer.norm", "token_mixer_block.token_mixer.norm") | |
if "token_mixer.mixer" in name: | |
name = name.replace("token_mixer.mixer", "token_mixer_block.token_mixer.mixer") | |
if "mlp" in name: | |
name = name.replace("mlp", "convffn") | |
if ".conv.conv" in name: | |
name = name.replace("conv.conv", "conv") | |
if ".conv.bn" in name: | |
name = name.replace("conv.bn", "bn") | |
if "proj." in name: | |
if "token_mixer" not in name: | |
name_split = name.split(".") | |
pos = int(name_split[2]) | |
name_split[2] = str(pos - 1) | |
if int(name_split[5]) == 0: | |
name_split[4] = "reparam_large_conv" | |
else: | |
name_split[4] = "conv" | |
name_split.pop(5) # drop the 0 or 1.... | |
name = ".".join(name_split) | |
else: | |
name = name.replace("token_mixer.proj", "token_mixer_block.attention.proj") | |
if "se.fc1" in name: | |
name = name.replace("se.fc1", "se.reduce") | |
if "se.fc2" in name: | |
name = name.replace("se.fc2", "se.expand") | |
if "q_bias" in name: | |
name = name.replace("q_bias", "query.bias") | |
if "k_bias" in name: | |
name = name.replace("k_bias", "key.bias") | |
if "v_bias" in name: | |
name = name.replace("v_bias", "value.bias") | |
for name_from, name_to in (
    ("stem", "embeddings.patch_embeddings.projection"),
    ("conv_kxk", "rbr_conv"),
    ("conv_scale", "rbr_scale"),
    ("identity", "rbr_skip"),
    ("0.conv", "conv"),
    ("0.bn", "bn"),
    ("stages", "encoder.layer"),
    ("blocks", "stage_conv"),
    ("layer_scale.gamma", "layer_scale"),
    ("token_mixer", "token_mixer_block.token_mixer"),
    ("layer_scale_1.gamma", "layer_scale_1"),
    ("layer_scale_2.gamma", "layer_scale_2"),
    ("token_mixer.norm", "token_mixer_block.token_mixer.norm"),
    ("se.fc1", "se.reduce"),
    ("se.fc2", "se.expand"),
    ("q_bias", "query.bias"),
    ("k_bias", "key.bias"),
    ("v_bias", "value.bias"),
    ("token_mixer.mixer", "token_mixer_block.token_mixer.mixer"),
    ("mlp", "convffn"),
):
    name = name.replace(name_from, name_to)
if ".conv.conv" in name:
    name = name.replace("conv.conv", "conv")
if ".conv.bn" in name:
    name = name.replace("conv.bn", "bn")
if "proj." in name:
    if "token_mixer" not in name:
        name_split = name.split(".")
        pos = int(name_split[2])
        name_split[2] = str(pos - 1)
        if int(name_split[5]) == 0:
            name_split[4] = "reparam_large_conv"
        else:
            name_split[4] = "conv"
        name_split.pop(5)  # drop the 0 or 1....
        name = ".".join(name_split)
    else:
        name = name.replace("token_mixer.proj", "token_mixer_block.attention.proj")
Please take into account that I followed the structure that was suggested in the template!! I guess they want it this way...
We don't necessarily enforce having what is in the template. I would suggest having a single dictionary that maps old keys to new keys for more readable code.
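For instance, a minimal sketch of that idea (only a few pairs shown, the rest elided; the replacements stay substring-based, and dicts preserve insertion order in Python 3.7+, so the ordering caveat above still applies):

KEY_RENAMES = {
    "stem": "embeddings.patch_embeddings.projection",
    "conv_kxk": "rbr_conv",
    "conv_scale": "rbr_scale",
    "identity": "rbr_skip",
    # ... remaining pairs elided ...
}

def rename_key(name):
    # Apply each substring rename in insertion order.
    for name_from, name_to in KEY_RENAMES.items():
        name = name.replace(name_from, name_to)
    return name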
    im = Image.open(requests.get(url, stream=True).raw)
    return im
return Image.open(requests.get(url, stream=True).raw) |
def convert_state_dict(orig_state_dict, model):
    for key in orig_state_dict.copy().keys():
    for key in orig_state_dict:
        val = orig_state_dict.pop(key)
        if "mask" in key:
            continue
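Note that because the loop body pops keys, iterating over the dict object directly (as in the suggestion above) would raise RuntimeError: dictionary changed size during iteration; iterating over a snapshot of the keys keeps the copy-free spirit of the suggestion. A minimal sketch:

    # list(orig_state_dict) snapshots just the keys, avoiding a full dict copy
    for key in list(orig_state_dict):
        val = orig_state_dict.pop(key)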
convert_state_dict() strips all *mask* elements of orig_state_dict. Is it intended?
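One quick way to verify is to list the keys the function would skip; a minimal sketch:

skipped = [key for key in orig_state_dict if "mask" in key]
print(skipped)  # check these are all meant to be dropped (that they are mask buffers is an assumption)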
Thank you for your contribution! :)
Just a small nit.
Hey @ArthurZucker, I did a few iterations and finished the first pass.
Could you please take it from here? :)
My first general comment is to use full camel case for the class name, so FastViT -> FastVit.
My second question is why the inference mode has to be passed. Guessing it's for memory efficiency? But then do you intend to push 2 different checkpoints, 1 for inference and the other for training?
Also, the docstrings of the classes referencing the model don't really help, as they don't describe what is happening inside. Let's add a diagram to FastVit.md, if we can, showing what makes it so fast.
Also, the SelfAttention layer seems pretty standard, it could have some # Copied from here!
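For reference, transformers marks layers reused verbatim from another model with a # Copied from comment that utils/check_copies.py keeps in sync; a sketch of what that could look like here (the exact source class and rename pattern are assumptions):

# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->FastVit
class FastVitSelfAttention(nn.Module):
    ...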
@@ -356,6 +356,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[ErnieM](https://huggingface.co/docs/transformers/model_doc/ernie_m)** (from Baidu) released with the paper [ERNIE-M: Enhanced Multilingual Representation by Aligning Cross-lingual Semantics with Monolingual Corpora](https://arxiv.org/abs/2012.15674) by Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang.
1. **[ESM](https://huggingface.co/docs/transformers/model_doc/esm)** (from Meta AI) are transformer protein language models. **ESM-1b** was released with the paper [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. **ESM-1v** was released with the paper [Language models enable zero-shot prediction of the effects of mutations on protein function](https://doi.org/10.1101/2021.07.09.450648) by Joshua Meier, Roshan Rao, Robert Verkuil, Jason Liu, Tom Sercu and Alexander Rives. **ESM-2 and ESMFold** were released with the paper [Language models of protein sequences at the scale of evolution enable accurate structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido, Alexander Rives.
1. **[Falcon](https://huggingface.co/docs/transformers/model_doc/falcon)** (from Technology Innovation Institute) by Almazrouei, Ebtesam and Alobeidli, Hamza and Alshamsi, Abdulaziz and Cappelli, Alessandro and Cojocaru, Ruxandra and Debbah, Merouane and Goffinet, Etienne and Heslow, Daniel and Launay, Julien and Malartic, Quentin and Noune, Badreddine and Pannier, Baptiste and Penedo, Guilherme.
1. **[FastViT](https://huggingface.co/docs/transformers/model_doc/fastvit)** (from Apple) released with the paper [FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) by Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel and Anurag Ranjan.
1. **[FastViT](https://huggingface.co/docs/transformers/main/model_doc/fastvit)** (from Apple) released with the paper [FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) by Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel and Anurag Ranjan.
Before a release, the doc should make sure to point to main.
This seems wrong, we should not have these changes. Make sure to remove this change (the table needs to be updated, but not the list of supported models).
FastViT is a hybrid Transformer with some several modifications, such as replacing denses with a factored version,
replace self-attention to large kernel convolutions, with the objective of reducing latency.
Needs to be rephrased and grammar-checked.
The FastViT model was proposed in [FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189) by Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel and Anurag Ranjan.
FastViT is a hybrid Transformer with some several modifications, such as replacing denses with a factored version,
replace self-attention to large kernel convolutions, with the objective of reducing latency.
The authors claims that FastViT is 3.5× faster than CMT, a recent state-of-the-art hybrid transformer architecture,
missing a link to CMT
FastViT is a hybrid Transformer with some several modifications, such as replacing denses with a factored version,
replace self-attention to large kernel convolutions, with the objective of reducing latency.
The authors claims that FastViT is 3.5× faster than CMT, a recent state-of-the-art hybrid transformer architecture,
4.9× faster than EfficientNet, and 1.9× faster than ConvNeXt on a mobile device for the same accuracy on the ImageNet dataset.
I think EfficientNet is part of transformers, we can also link it!
qkv = (
    self.qkv(hidden_states)
    .reshape(batch_size, num_patches, 3, self.num_heads, self.num_attention_heads)
    .permute(2, 0, 3, 1, 4)
)
should be split into two lines
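For example, one possible split that keeps the attribute names from the diff as-is:

qkv = self.qkv(hidden_states)
qkv = qkv.reshape(batch_size, num_patches, 3, self.num_heads, self.num_attention_heads).permute(2, 0, 3, 1, 4)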
# Convert input from shape (batch_size, channels, orig_height, orig_width)
# to the shape (batch_size * patch_area, num_patches, channels)
hidden_states = torch.flatten(hidden_states, start_dim=2).transpose(-2, -1)  # B N C
hidden_states = torch.flatten(hidden_states, start_dim=2).transpose(-2, -1)
        return hidden_state


class FastViTAttention(nn.Module):
Seems way too similar to the FastVitAttention to require a new class.
class FastViTMixer(nn.Module):
    """
    This class is an implementation of Metaformer block with RepMixer as token mixer. For more info: `MetaFormer Is
    Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(self, config: FastViTConfig, stage: str) -> None:
        super().__init__()
        self.token_mixer = FastViTRepMixer(config, stage)

    def forward(self, hidden_states: torch.tensor) -> torch.tensor:
        hidden_states = self.token_mixer(hidden_states)
        return hidden_states
This does not need a class
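i.e., roughly this in the parent block, assuming it currently instantiates FastViTMixer (a sketch, not the PR's actual code):

# instead of: self.token_mixer = FastViTMixer(config, stage)
self.token_mixer = FastViTRepMixer(config, stage)
# forward() then calls it directly: hidden_states = self.token_mixer(hidden_states)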
        return hidden_states


class FastViTCPE(nn.Module):
Why not use the full name for the class?
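For example, assuming CPE here abbreviates the conditional positional encoding used in the paper:

class FastViTConditionalPositionalEncoding(nn.Module):  # instead of FastViTCPE
    ...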
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @JorgeAV-ai are you planning to work further on this PR?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi, are you still working on it?
@RobotiX101 As @JorgeAV-ai didn't reply to this Q, I think it's safe to assume the PR is inactive. If you or anyone else would like to try and tackle adding the model, we'd be happy to review a PR!
OK, I will try.
@amyeroberts Hi, can I take this on, if it's up for grabs and the maintainers are interested in adding this model?
@RUFFY-369 The work is open for anyone who wishes to pick it up :) We prioritise reviews based on when PRs are opened, rather than claims on issues, as we find this helps avoid things going stale.
@amyeroberts Thank you for your reply. I will take a look at this soon after getting completely done with ProPainter 👍 😄
What does this PR do?
Fixes #25526
I have seen that the issue is still open and no PR has been submitted during these weeks, so I have decided to open mine once I finish the model structure + testing.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @amyeroberts